Track generated functions for torch compile #1094

Merged: 2 commits into pymc-devs:main on Nov 25, 2024

Conversation

@Ch0ronomato (Contributor) commented Nov 17, 2024

Description

The torch compiler was having trouble running PyTensor graphs that contain subgraphs. The way PyTorch resolves modules caused Dynamo guard creation to fail; Torch requires guards to be created, and to hold at runtime, so failing to build one meant the graph couldn't be compiled. This PR works around that by attaching the functions we generate to a module, giving them an explicit lifecycle. That removes the need for graph breaks, which means faster compile times and better compilation results.

For this code:

import numpy as np
import torch._dynamo as dynamo

from pytensor.configdefaults import config
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.tensor.type import matrices

x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y])
ofg_2 = OpFromGraph([x, y], [x * y, x - y])

o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) / o2

xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5

f = function([x, y, z], out, mode="PYTORCH")
print(f(xv, yv, zv))

With the change:

Graph Count: 1
Graph Break Count: 0
Op Count: 4
Break Reasons:
Ops per Graph:
  Ops 1:
    <built-in method multiply of type object at 0x128d5a9a0>
    <built-in method subtract of type object at 0x128d5a9a0>
    <built-in method add of type object at 0x128d5a9a0>
    <built-in method true_divide of type object at 0x128d5a9a0>
Number of Out Guards: 58
Compile Times: TorchDynamo compilation metrics:
Function, Runtimes (s)
_compile.<locals>.compile_inner, 0.0773
OutputGraph.call_user_compiler, 0.0006

Without the change:

(pytensor-dev) ➜  pytensor git:(33a4d4882) ✗ python test.py
Graph Count: 4
Graph Break Count: 3
Op Count: 4
Break Reasons:
Ops per Graph:
  Ops 1:
    <built-in method multiply of type object at 0x11ed5a9a0>
  Ops 2:
    <built-in method subtract of type object at 0x11ed5a9a0>
  Ops 3:
    <built-in method add of type object at 0x11ed5a9a0>
  Ops 4:
    <built-in method true_divide of type object at 0x11ed5a9a0>
Number of Out Guards: 140
Compile Times: TorchDynamo compilation metrics:
Function, Runtimes (s)
_compile.<locals>.compile_inner, 0.0397, 0.0174, 0.0145, 0.0056, 0.0147, 0.0151
OutputGraph.call_user_compiler, 0.0007, 0.0000, 0.0000, 0.0000
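
(For reference: the graph and guard statistics above are in the format printed by torch._dynamo's explain utility. A minimal sketch of collecting such stats for an arbitrary callable, assuming a recent torch version; this is not the exact script used to produce the numbers above.)

import torch
import torch._dynamo as dynamo

def g(a, b):
    return (a * b) + (a - b)

# dynamo.explain wraps a callable; calling the wrapper with example inputs
# returns an ExplainOutput whose string form reports graph count, graph-break
# count, break reasons, and guard counts, roughly as shown above.
explanation = dynamo.explain(g)(torch.ones(2, 2), torch.ones(2, 2) * 3)
print(explanation)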

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1094.org.readthedocs.build/en/1094/

@Ch0ronomato changed the title from "[WIP] Make lifetime of closure explicit (fingers crossed unit tests pass)" to "[WIP] Make lifetime of closure explicit" on Nov 18, 2024
codecov bot commented Nov 18, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.12%. Comparing base (6de3151) to head (aa6aac2).
Report is 134 commits behind head on main.

@@            Coverage Diff             @@
##             main    #1094      +/-   ##
==========================================
- Coverage   82.12%   82.12%   -0.01%     
==========================================
  Files         183      183              
  Lines       47986    48016      +30     
  Branches     8644     8648       +4     
==========================================
+ Hits        39409    39433      +24     
- Misses       6411     6417       +6     
  Partials     2166     2166              
Files with missing lines                      Coverage Δ
pytensor/link/pytorch/dispatch/basic.py       94.49% <100.00%> (+0.05%) ⬆️
pytensor/link/pytorch/dispatch/blockwise.py   70.00% <100.00%> (-30.00%) ⬇️
pytensor/link/pytorch/linker.py               100.00% <100.00%> (ø)
pytensor/link/utils.py                        59.75% <100.00%> (+0.24%) ⬆️

@Ch0ronomato changed the title from "[WIP] Make lifetime of closure explicit" to "Make lifetime of closure explicit" on Nov 18, 2024
@Ch0ronomato marked this pull request as ready for review on November 18, 2024 22:43
def conversion_func_register(*args, **kwargs):
    functor = pytorch_funcify(*args, **kwargs)
    # Register the generated functor on pytensor.link.utils under a unique name
    # so it stays resolvable while torch.compile builds its guards.
    module = pytensor.link.utils
    setattr(module, kwargs["unique_name"](functor), functor)
Member:

Do we need a weakref perhaps? So memory does get freed if nothing else is using these functions?

Contributor (Author):

Ah that's a good idea. Let me try doing weakref.

Contributor (Author), Nov 19, 2024:

I tried using weakref (à la weakref.ref) in a few spots and couldn't get it to work. Maybe more importantly, it's pretty intrusive: the generated code now has to know it's holding a weakref and call it differently, all backends would need weakrefs, etc. I'm not sure this is a good path.
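
(A minimal standalone sketch of the failure mode, not PyTensor code: when the only strong reference to a generated functor is replaced by a bare weakref, the functor is collected and the reference is already dead by the time the compiled code tries to call it.)

import weakref

def make_functor():
    # Stand-in for a generated functor; the name is hypothetical.
    def functor(a, b):
        return a + b
    return functor

ref = weakref.ref(make_functor())
# No strong reference to the functor remains, so it is collected immediately
# and the weak reference now resolves to None.
print(ref())  # None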

Member:

The problem with this is that PyTensor becomes a giant memory leak if you compile enough pytensor functions?

Can we get input from the torch devs now that we've narrowed down the problem?

Contributor (Author):

I'll definitely ask the torch team to help figure it out, and post here. Probably before we merge this, I imagine?

On the point about a giant memory leak, I'm honestly not sure that'll be the case, and I'm not particularly worried. These methods work as-is if you disable torch compile, which tells me the closures are already being kept alive somewhere. When profiling the script above with memray, allocations and memory use stay within normal wiggle room (±5 allocations, ~50 KB). Here are the zips.
prof.zip

Contributor (Author), Nov 19, 2024:

This is the issue we end up seeing with weakref:

torch._dynamo.exc.InternalTorchDynamoError: weakly-referenced object no longer exists

from user code:
   File "/var/folders/98/g1t2_d2x4w94vqfv06xhyz6c0000gp/T/tmpnk75e974", line 3, in pytorch_funcified_fgraph
    tensor_variable_2, tensor_variable_3 = pytorch_funcified_fgraph(y, z)

Member:

Is there a way to make the function graph that torch keeps around be the one referencing those closures? link.utils will never go away unless someone does del pytensor.

Contributor (Author):

Maybe we can add a __del__ method in the torch linker, and have that clean up these explicit references?

Additionally, yeah, I'll see if I can cheese something where the global fgraph holds a reference somewhere...

Contributor (Author):

I tried adding a sort of wrapper execution function that lets us clean up after ourselves. I'm assuming that each pytensor.function will have its own PytorchLinker; let me know if that's not correct. This successfully cleaned up the references after the new function goes out of scope.

ricardoV94 (Member), Nov 22, 2024:

Yes the function will have its own PyTorchLinker

        self.gen_functors = copy.copy(gen_functors)

    def __call__(self, *args, **kwargs):
        import pytensor.link.utils
Contributor (Author):

Note, there is no way this is threadsafe.
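
(A minimal sketch of the wrapper pattern under discussion, assuming gen_functors is a list of (name, functor) pairs; the class name is hypothetical and this is not the exact code merged in this PR.)

import copy
import pytensor.link.utils

class WrappedFn:
    def __init__(self, fn, gen_functors):
        self.fn = fn
        self.gen_functors = copy.copy(gen_functors)

    def __call__(self, *args, **kwargs):
        # Pin the generated functors on a long-lived module so Dynamo can build
        # guards against stable attributes, then drop them afterwards so they
        # don't outlive this function. As noted above, this is not threadsafe.
        for name, functor in self.gen_functors:
            setattr(pytensor.link.utils, name, functor)
        try:
            return self.fn(*args, **kwargs)
        finally:
            for name, _ in self.gen_functors:
                delattr(pytensor.link.utils, name)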

ricardoV94 (Member) left a comment:

Looks better. Can you add a test showing the extra attributes on the module being discarded when you del the function?

Also, can we store them with a leading _, so they don't clutter the module namespace during debugging?

@@ -34,3 +71,6 @@ def create_thunk_inputs(self, storage_map):
            thunk_inputs.append(sinput)

        return thunk_inputs

    def record_fn(self, name, fn):
Member:

Is this used?

Contributor (Author):

I actually removed this and moved the "conversion" func into the linker, so now the linker is doing all the tracking and stuff. Seemed more reasonable.

@ricardoV94 (Member):

And a more readable PR title?

@Ch0ronomato changed the title from "Make lifetime of closure explicit" to "Track generated functions for torch compile" on Nov 23, 2024
@Ch0ronomato force-pushed the torch_compiler branch 2 times, most recently from 7631a68 to 07e6113, on November 23, 2024 05:35
@Ch0ronomato (Contributor, Author):

Okay @ricardoV94, if this still looks good, it should be ready to merge.

@ricardoV94 added the "enhancement (New feature or request)", "torch (PyTorch backend)", and "maintenance" labels, then removed the "enhancement (New feature or request)" label, on Nov 25, 2024
@ricardoV94 merged commit ae66e82 into pymc-devs:main on Nov 25, 2024
62 checks passed
Labels: maintenance, torch (PyTorch backend)